# coding:utf-8
'''
author:wangyi
'''
'''
TextCNN model pipeline:
1. embedding lookup
2. convolution (multiple filter sizes, inception-style)
3. max pooling
4. fully-connected layer (only one)
5. classification
'''
import tensorflow as tf

class TextCNN:
    """TextCNN text classifier (TensorFlow 1.x graph mode).

    Pipeline: embedding lookup -> parallel convolutions over several filter
    sizes -> max-pool over time -> single fully-connected layer -> logits.
    Call order: ``__init__`` -> ``cnn()`` -> ``fully_connections()`` -> ``op()``.
    """

    def __init__(self, embedding_dim, vocab_size, max_seq_len, filters, filter_nums, fc_dim, class_nums):
        """Build placeholders and the embedding table.

        :param embedding_dim: dimensionality of the word vectors
        :param vocab_size: vocabulary size
        :param max_seq_len: maximum sentence length (inputs padded/truncated to this)
        :param filters: list of convolution kernel heights, e.g. [2, 3, 4]
        :param filter_nums: number of kernels per filter size
        :param fc_dim: intended hidden size of the FC layer
            (NOTE(review): currently unused — logits are produced directly from
            the pooled features; kept for interface compatibility)
        :param class_nums: number of output classes
        """
        # Word-vector dimensionality
        self.embedding_dim = embedding_dim
        # Vocabulary size
        self.vocab_size = vocab_size
        # Maximum sentence length
        self.max_seq_len = max_seq_len
        # Convolution kernel heights, e.g. [2, 3, 4]
        self.filters = filters
        # Number of kernels per filter size
        self.filter_nums = filter_nums
        # FC hidden size (unused, see docstring)
        self.fc_dim = fc_dim
        # Number of classes
        self.class_nums = class_nums

        # Embedding table: [vocab_size, embedding_dim]
        self.embeddings = tf.Variable(tf.truncated_normal(shape=[self.vocab_size,
                                                                 self.embedding_dim], stddev=0.1))
        # Token-id input, e.g. [1,3,4,5,6,...]: [batch_size, max_seq_len]
        self.input_x = tf.placeholder(dtype=tf.int32, shape=[None, self.max_seq_len])
        # Look up embeddings: [batch_size, max_seq_len, embedding_dim]
        self.input = tf.nn.embedding_lookup(self.embeddings, self.input_x)
        # Append a channel axis so conv2d can operate:
        # [batch_size, max_seq_len, embedding_dim, 1]
        self.input = tf.expand_dims(self.input, -1)
        # One-hot ground-truth labels: [batch_size, class_nums]
        self.input_y = tf.placeholder(tf.float32, shape=[None, self.class_nums])

    def cnn(self):
        """Run one conv + max-pool branch per filter size; collect results in self.pools."""
        self.pools = []
        for filter_size in self.filters:  # renamed from `filter` (shadowed the builtin)
            # Kernel: [filter_size, embedding_dim, 1, filter_nums]
            # stddev=0.1 added for consistency with the other weight initializers
            # (the original used the default stddev=1.0 here, giving poorly scaled init).
            F = tf.Variable(tf.truncated_normal(shape=[filter_size, self.embedding_dim,
                                                       1, self.filter_nums], stddev=0.1))
            B = tf.Variable(tf.constant(0.1, shape=[self.filter_nums]))

            # Convolution with VALID padding:
            # [batch_size, max_seq_len - filter_size + 1, 1, filter_nums]
            conv = tf.nn.conv2d(self.input, F, strides=[1, 1, 1, 1], padding='VALID') + B
            # ReLU nonlinearity (standard TextCNN applies it before pooling;
            # the original omitted it, leaving the conv stack linear up to the pool).
            conv = tf.nn.relu(conv)
            # Max-pool over the whole time axis: [batch_size, 1, 1, filter_nums]
            conv = tf.nn.max_pool(conv, ksize=[1, self.max_seq_len - filter_size + 1, 1, 1],
                                  strides=[1, 1, 1, 1], padding='VALID')
            self.pools.append(conv)

    def fully_connections(self):
        """Concatenate pooled features and project to class logits (sets self.logits)."""
        # Concatenate along the channel axis:
        # [batch_size, 1, 1, len(self.filters) * filter_nums]
        h = tf.concat(self.pools, axis=3)
        # Flatten: [batch_size, len(self.filters) * filter_nums]
        h = tf.reshape(h, shape=[-1, self.filter_nums * len(self.filters)])

        Fw = tf.Variable(tf.truncated_normal(shape=[self.filter_nums * len(self.filters),
                                                    self.class_nums], stddev=0.1))
        Fb = tf.Variable(tf.constant(0.0, shape=[self.class_nums]))

        # Unnormalized class scores: [batch_size, class_nums]
        self.logits = tf.matmul(h, Fw) + Fb

    def op(self, learning_rate):
        """Define the loss and the training op.

        :param learning_rate: Adam learning rate
        """
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=self.input_y, logits=self.logits))
        # BUG FIX: the original assigned the bare optimizer object to train_op
        # and never called minimize(), so no gradients were ever applied.
        self.train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)








